/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package main;

/**
 * Mean squared error (MSE) loss function implementation.
 * <p>
 * Implemented by the architect.
 *
 * @author ZSQ
 */
public class MeanSquaredError implements LossFunction {

    /**
     * Computes the mean squared error between predictions and targets:
     * {@code (1/n) * sum((predicted[i] - actual[i])^2)}.
     *
     * @param predicted the predicted values; must not be {@code null}
     * @param actual    the target values; must not be {@code null} and must have the
     *                  same length as {@code predicted}
     * @return the mean of the squared element-wise differences, or {@code NaN} when
     *         both arrays are empty (0.0 / 0 in floating point)
     * @throws NullPointerException     if either array is {@code null}
     * @throws IllegalArgumentException if the arrays differ in length
     */
    @Override
    public double calculate(double[] predicted, double[] actual) {
        requireSameLength(predicted, actual);
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            double error = predicted[i] - actual[i];
            sum += error * error;
        }
        // For empty input this is 0.0 / 0, which yields NaN: the mean of zero terms is undefined.
        return sum / predicted.length;
    }

    /**
     * Computes the gradient of {@link #calculate(double[], double[])} with respect to
     * each predicted value: {@code 2 * (predicted[i] - actual[i]) / n}.
     *
     * @param predicted the predicted values; must not be {@code null}
     * @param actual    the target values; must not be {@code null} and must have the
     *                  same length as {@code predicted}
     * @return a new array of length {@code n} holding the partial derivatives
     *         (empty when both inputs are empty)
     * @throws NullPointerException     if either array is {@code null}
     * @throws IllegalArgumentException if the arrays differ in length
     */
    @Override
    public double[] derivative(double[] predicted, double[] actual) {
        requireSameLength(predicted, actual);
        int n = predicted.length;
        double[] derivatives = new double[n];
        for (int i = 0; i < n; i++) {
            // 2 comes from d(e^2)/de; the 1/n factor comes from the mean taken in calculate().
            derivatives[i] = 2.0 * (predicted[i] - actual[i]) / n;
        }
        return derivatives;
    }

    /**
     * Ensures both arrays have the same length, so that no target is silently ignored
     * and no index runs past the end of {@code actual}.
     *
     * @throws NullPointerException     if either array is {@code null}
     * @throws IllegalArgumentException if the lengths differ
     */
    private static void requireSameLength(double[] predicted, double[] actual) {
        if (predicted.length != actual.length) {
            throw new IllegalArgumentException(
                    "predicted and actual must have the same length: "
                    + predicted.length + " != " + actual.length);
        }
    }
}
